{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Relay Pass Instrument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "@tvm.instrument.pass_instrument\n",
    "class PassCounter:\n",
    "    def __init__(self):\n",
    "        # Just setting a garbage value to test set_up callback\n",
    "        self.counts = 1234\n",
    "\n",
    "    def enter_pass_ctx(self):\n",
    "        self.counts = 0\n",
    "\n",
    "    def exit_pass_ctx(self):\n",
    "        self.counts = 0\n",
    "\n",
    "    def run_before_pass(self, module, info):\n",
    "        self.counts += 1\n",
    "\n",
    "    def get_counts(self):\n",
    "        return self.counts\n",
    "\n",
    "\n",
    "def test_print_debug_callback():\n",
    "    shape = (1, 2, 3)\n",
    "    tp = relay.TensorType(shape, \"float32\")\n",
    "    x = relay.var(\"x\", tp)\n",
    "    y = relay.add(x, x)\n",
    "    y = relay.multiply(y, relay.const(2, \"float32\"))\n",
    "    func = relay.Function([x], y)\n",
    "\n",
    "    seq = tvm.transform.Sequential(\n",
    "        [\n",
    "            relay.transform.InferType(),\n",
    "            relay.transform.FoldConstant(),\n",
    "            relay.transform.DeadCodeElimination(),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    mod = tvm.IRModule({\"main\": func})\n",
    "\n",
    "    pass_counter = PassCounter()\n",
    "    with tvm.transform.PassContext(opt_level=3, instruments=[pass_counter]):\n",
    "        # Should be reseted when entering pass context\n",
    "        assert pass_counter.get_counts() == 0\n",
    "        mod = seq(mod)\n",
    "\n",
    "        # TODO(@jroesch): when we remove new fn pass behavior we need to remove\n",
    "        # change this back to match correct behavior\n",
    "        assert pass_counter.get_counts() == 6\n",
    "\n",
    "    # Should be cleanned up after exiting pass context\n",
    "    assert pass_counter.get_counts() == 0\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('py38': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
